# 15. K-means Segmentation
!pip install moviepy scikit-image scikit-learn
Show code cell source
import os
import cv2
import numpy as np
import requests
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import plotly.express as px
from moviepy.editor import VideoFileClip, ImageSequenceClip
from sklearn.cluster import KMeans
from collections import Counter
from skimage.segmentation import slic
from skimage.util import img_as_float
from IPython.display import HTML, display
from base64 import b64encode
import ipywidgets as widgets
import pandas as pd
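# Captured output of the import cell (a harmless warning emitted inside moviepy, not part of this code):
# WARNING:py.warnings:/usr/local/lib/python3.10/dist-packages/moviepy/video/io/sliders.py:61: SyntaxWarning: "is" with a literal. Did you mean "=="?
#   if event.key is 'enter':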
def download_video(video_url, save_path, timeout=60, chunk_size=8192):
    """Download ``video_url`` to ``save_path`` by streaming the response body.

    Args:
        video_url: URL of the video to fetch.
        save_path: Local file path the bytes are written to (overwritten if present).
        timeout: Seconds to wait for the server to connect/respond, so a stalled
            server cannot hang the notebook indefinitely.
        chunk_size: Number of bytes read from the stream per iteration.

    Returns:
        True if the file was downloaded and saved, False if the server answered
        with a non-200 status. A status message is printed in both cases.
    """
    # Use the response as a context manager so the streamed connection is always
    # released, including on the failure path and if writing to disk raises.
    with requests.get(video_url, stream=True, timeout=timeout) as response:
        if response.status_code != 200:
            print(f"Failed to download video. Status code: {response.status_code}")
            return False
        with open(save_path, 'wb') as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                # iter_content may yield empty keep-alive chunks; skip them.
                if chunk:
                    f.write(chunk)
    print(f"Video downloaded successfully and saved to: {save_path}")
    return True
def extract_frames_from_video(video_path, output_dir, frame_rate=10, width=1024, height=1024):
    """Sample frames from a local video, resize them and save them as PNG files.

    Frames are written as ``frame_0000.png``, ``frame_0001.png``, ... into
    ``output_dir`` (created if missing). Each frame is stretched to exactly
    ``width`` x ``height``; the source aspect ratio is not preserved.

    Args:
        video_path: Path of the local video file to read.
        output_dir: Directory the PNG frames are written to.
        frame_rate: Number of frames sampled per second of video.
        width: Output frame width in pixels.
        height: Output frame height in pixels.

    Returns:
        The number of frames written.
    """
    os.makedirs(output_dir, exist_ok=True)
    clip = VideoFileClip(video_path)
    count = 0
    try:
        for i, frame in enumerate(clip.iter_frames(fps=frame_rate)):
            resized_frame = cv2.resize(frame, (width, height))
            # moviepy yields RGB frames but cv2.imwrite expects BGR; convert so the
            # PNGs on disk have correct colours for the cv2.imread-based steps.
            bgr_frame = cv2.cvtColor(resized_frame, cv2.COLOR_RGB2BGR)
            frame_path = os.path.join(output_dir, f'frame_{i:04d}.png')
            cv2.imwrite(frame_path, bgr_frame)
            count += 1
    finally:
        # Release the video reader held by the clip, even if a frame fails.
        clip.close()
    print(f"Frames extracted and saved to: {output_dir}")
    return count
def apply_kmeans(image, n_clusters, kmeans_model=None):
    """Quantise an image's colours with K-means.

    Each pixel's channel values form one feature vector. If ``kmeans_model`` is
    None, a new ``KMeans(n_clusters=n_clusters, random_state=42)`` is fitted on
    this image. Otherwise the supplied, already fitted model is reused and
    ``n_clusters`` is ignored; this keeps cluster indices consistent across frames.

    Args:
        image: ``(H, W, C)`` or ``(H, W)`` array of 8-bit pixel values (0-255).
        n_clusters: Number of clusters used when a new model is fitted.
        kmeans_model: Optional fitted model exposing ``predict`` and
            ``cluster_centers_``.

    Returns:
        Tuple ``(segmented_image, labels, kmeans)``:
          * ``segmented_image``: uint8 array of ``image.shape`` in which every
            pixel is replaced by its cluster centre;
          * ``labels``: flat per-pixel cluster indices;
          * ``kmeans``: the model that was used.
    """
    # Support any channel count, including single-channel 2-D images.
    n_channels = image.shape[-1] if image.ndim == 3 else 1
    pixel_values = np.float32(image.reshape((-1, n_channels)))
    if kmeans_model is None:
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
        kmeans.fit(pixel_values)
    else:
        kmeans = kmeans_model
    labels = kmeans.predict(pixel_values)
    segmented_image = kmeans.cluster_centers_[labels].reshape(image.shape)
    # Cluster centres are float means. Round them (a bare uint8 cast would
    # truncate) and clip before casting so the result is a valid 8-bit image.
    segmented_image = np.clip(np.rint(segmented_image), 0, 255).astype(np.uint8)
    return segmented_image, labels, kmeans
def preprocess_image(image, crop_height=400, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Crop the bottom of a BGR image and enhance its brightness contrast with CLAHE.

    CLAHE (Contrast Limited Adaptive Histogram Equalization) is applied to the
    V channel of the HSV representation only; the H and S channels are left
    untouched before converting back to BGR.

    Args:
        image: BGR ``(H, W, 3)`` uint8 image, e.g. from ``cv2.imread``.
        crop_height: Keep only the bottom ``crop_height`` rows. ``None`` or a
            non-positive value keeps the whole image. An image shorter than
            ``crop_height`` is kept whole.
        clip_limit: CLAHE contrast-limit threshold.
        tile_grid_size: CLAHE tile grid size (number of tiles per row/column).

    Returns:
        The cropped, contrast-enhanced image in BGR.
    """
    if crop_height is not None and crop_height > 0:
        image = image[-crop_height:]
    hsv_image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    # Equalise only the value (brightness) channel.
    hsv_image[:, :, 2] = clahe.apply(hsv_image[:, :, 2])
    return cv2.cvtColor(hsv_image, cv2.COLOR_HSV2BGR)
def segment_image(image, n_segments):
    """Over-segment ``image`` into SLIC superpixels.

    The image is first converted to floating point with ``img_as_float``. SLIC
    then runs with a fixed compactness of 10, and labels are numbered from 0.

    Args:
        image: Image array accepted by ``img_as_float``.
        n_segments: Approximate number of superpixels requested from SLIC.

    Returns:
        The superpixel label array produced by ``slic``.
    """
    return slic(
        img_as_float(image),
        n_segments=n_segments,
        compactness=10,
        start_label=0,
    )
def process_and_save_image(image, kmeans_model, filename, output_dir):
    """Colour-quantise one frame with a K-means model and save the results.

    The frame is cropped and contrast-enhanced with ``preprocess_image``, and
    every pixel is replaced by its cluster centre. Two files are written to
    ``output_dir``:

      * ``filename``: the segmented image, without text;
      * ``<stem>_clusters.txt``: one line per cluster with its pixel count,
        share of the frame and centre colour in OpenCV 8-bit HSV.

    NOTE: if ``output_dir`` is the directory the frame was read from, the
    source frame file is overwritten.

    Args:
        image: BGR frame, e.g. from ``cv2.imread``.
        kmeans_model: Fitted K-means model shared across frames, so cluster
            indices mean the same thing in every frame. If None, a 4-cluster
            model is fitted on this frame alone.
        filename: Output image file name; also used to derive the text file name.
        output_dir: Directory the outputs are written to.

    Returns:
        The segmented BGR image with the cluster summary drawn on it.
    """
    preprocessed = preprocess_image(image)
    # Use the model returned by apply_kmeans, which also covers kmeans_model=None.
    segmented_image, labels, model = apply_kmeans(preprocessed, n_clusters=4, kmeans_model=kmeans_model)
    centers_bgr = model.cluster_centers_
    n_clusters = len(centers_bgr)

    # preprocess_image returns BGR, so the centres of a model fitted on
    # preprocessed frames (as process_frames does) are BGR colours. Convert them
    # to OpenCV 8-bit HSV so the "HSV" figures in the report really are HSV.
    centers_u8 = np.clip(np.rint(centers_bgr), 0, 255).astype(np.uint8).reshape(1, n_clusters, 3)
    centers_hsv = cv2.cvtColor(centers_u8, cv2.COLOR_BGR2HSV)[0]

    counts = np.bincount(labels, minlength=n_clusters)
    total_pixels = preprocessed.shape[0] * preprocessed.shape[1]
    cluster_info = []
    for i in range(n_clusters):
        percentage = (counts[i] / total_pixels) * 100
        h, s, v = (float(c) for c in centers_hsv[i])
        cluster_info.append(f"Cluster {i}: {counts[i]} pixels ({percentage:.2f}%) - HSV: ({h:.2f}, {s:.2f}, {v:.2f})")

    # segmented_image is already in BGR (the colour space of the preprocessed
    # frame), so it is saved as-is. An HSV->BGR conversion here would scramble it.
    output_image_path = os.path.join(output_dir, filename)
    cv2.imwrite(output_image_path, segmented_image)
    output_txt_path = os.path.join(output_dir, filename.rsplit('.', 1)[0] + '_clusters.txt')
    with open(output_txt_path, 'w') as f:
        f.write("\n".join(cluster_info))

    # Draw the summary only on the returned copy; the PNG on disk stays text-free.
    annotated = segmented_image.copy()
    for i, info in enumerate(cluster_info):
        text_position = (10, 30 + i * 20)
        cv2.putText(annotated, info, text_position, cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2)
    return annotated
def _natural_sort_key(name):
    """Sort key that orders embedded integers numerically (``frame_2`` before ``frame_10``)."""
    import re
    return [int(part) if part.isdecimal() else part for part in re.split(r'(\d+)', name)]


def _read_image(path):
    """Read an image with OpenCV, raising ValueError instead of returning None on failure."""
    image = cv2.imread(path)
    if image is None:
        raise ValueError(f"Could not read image: {path}")
    return image


def process_frames(output_dir):
    """Segment every frame image in ``output_dir`` with one shared K-means model.

    A 4-cluster model is fitted on the first frame in file-name order. It is
    then reused for all frames, so cluster indices are comparable over time.
    ``process_and_save_image`` writes results back into ``output_dir`` under
    the same file names, overwriting the source frames. Running this twice
    therefore re-segments already segmented frames.

    Args:
        output_dir: Directory containing the extracted frame images.

    Returns:
        List of annotated BGR frames in playback (file-name) order; ``[]`` if the
        directory contains no frame images.

    Raises:
        ValueError: If a frame image cannot be read.
    """
    # os.listdir order is arbitrary. Sort naturally so frames come out in
    # playback order for the video built from the returned list, and keep only
    # image files so the *_clusters.txt reports are never treated as frames.
    frame_names = sorted(
        (name for name in os.listdir(output_dir)
         if name.lower().endswith(('.png', '.jpg', '.jpeg'))),
        key=_natural_sort_key,
    )
    if not frame_names:
        print("No frames available for processing.")
        return []

    first_image = _read_image(os.path.join(output_dir, frame_names[0]))
    _, _, kmeans_model = apply_kmeans(preprocess_image(first_image), n_clusters=4)

    processed_frames_list = []
    for filename in frame_names:
        image = _read_image(os.path.join(output_dir, filename))
        processed_frames_list.append(process_and_save_image(image, kmeans_model, filename, output_dir))
    print("Segmentation completed, results saved.")
    return processed_frames_list
def create_video_from_frames(frames, output_video_path, fps=10):
    """Encode a sequence of BGR frames into a video file with the libx264 codec.

    Args:
        frames: Iterable of equally sized BGR uint8 images, e.g. the list
            returned by ``process_frames``.
        output_video_path: Destination path of the video.
        fps: Frame rate of the written video.

    Raises:
        ValueError: If ``frames`` is None or empty.
    """
    if frames is None:
        frames = []
    # moviepy expects RGB frames, while OpenCV produced BGR.
    rgb_frames = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frames]
    if not rgb_frames:
        raise ValueError("No frames to write: 'frames' is empty.")
    clip = ImageSequenceClip(rgb_frames, fps=fps)
    try:
        clip.write_videofile(output_video_path, codec='libx264')
    finally:
        # Release the clip's resources even if encoding fails.
        clip.close()
    print(f"Video saved to: {output_video_path}")
video_url = "https://github.com/atticus-carter/cv/raw/refs/heads/main/videos/output_video_8.avi"
video_path = "/content/2022SHRSubset.avi"
frames_output_dir = "/content/frames"  # Change to your desired output directory
segmented_video_output_path = "/content/segmented_videos/segmented_video.mp4"

# Create the frame directory and the directory that will hold the output video.
os.makedirs(frames_output_dir, exist_ok=True)
os.makedirs(os.path.dirname(segmented_video_output_path), exist_ok=True)

# Download the video. A failed download is only reported by a printed message,
# so verify the file exists before handing it to moviepy.
download_video(video_url, video_path)
if not os.path.isfile(video_path):
    raise FileNotFoundError(f"Video not found at {video_path}; did the download fail?")

# Extract frames from the local video.
# NOTE: frames left in frames_output_dir by an earlier run are processed too.
extract_frames_from_video(video_path, frames_output_dir)

# Segment the frames, then build a video from them if any were produced.
processed_frames = process_frames(frames_output_dir)
if processed_frames:
    create_video_from_frames(processed_frames, segmented_video_output_path)
else:
    print("No processed frames; skipping video creation.")
Show code cell source
def display_cluster_colors_with_image(cluster_file_path, image_path, video_path):
    """Show the raw video crop, the clustered frame and a swatch per cluster colour.

    Three figures are produced:
      * the bottom 400 rows of the video's first frame ("Original Image Clip",
        Plotly); this crop is also saved to ``/content/cropped_image.png``;
      * ``frame_0000.png`` from ``image_path`` ("Clustered Image", Plotly);
      * one Matplotlib colour swatch per cluster. Colours are parsed from the
        ``HSV: (h, s, v)`` entries of ``frame_0000_clusters.txt`` in
        ``cluster_file_path`` and interpreted as OpenCV 8-bit HSV.

    Args:
        cluster_file_path: Directory containing ``frame_0000_clusters.txt``.
        image_path: Directory containing ``frame_0000.png``.
        video_path: Path of the source video.

    Raises:
        FileNotFoundError: If the cluster text file or the frame image is missing.
    """
    clip = VideoFileClip(video_path)
    try:
        # moviepy returns frames in RGB channel order.
        first_frame = clip.get_frame(0)
    finally:
        clip.close()
    # Keep only the bottom 400 rows.
    first_frame_cropped = first_frame[-400:, :, :]

    # cv2.imwrite expects BGR, so convert the RGB crop before saving it.
    os.makedirs("/content", exist_ok=True)
    cv2.imwrite("/content/cropped_image.png", cv2.cvtColor(first_frame_cropped, cv2.COLOR_RGB2BGR))

    # The crop is already RGB, which is what Plotly expects; no channel swap.
    fig_original = px.imshow(first_frame_cropped)
    fig_original.update_layout(title="Original Image Clip")
    fig_original.show()

    cluster_file = os.path.join(cluster_file_path, "frame_0000_clusters.txt")
    with open(cluster_file, 'r') as f:
        hsv_colors = [
            [float(x) for x in line.split('HSV: ', 1)[1].strip().strip('()').split(',')]
            for line in f
            if 'HSV: ' in line
        ]

    image_file = os.path.join(image_path, "frame_0000.png")
    image_bgr = cv2.imread(image_file)
    if image_bgr is None:
        raise FileNotFoundError(f"Could not read image: {image_file}")
    image_rgb_clust = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    fig_image = px.imshow(image_rgb_clust)
    fig_image.update_layout(title="Clustered Image")
    fig_image.show()

    if not hsv_colors:
        print(f"No HSV entries found in {cluster_file}; skipping colour swatches.")
        return

    # squeeze=False keeps `axes` 2-D, so indexing also works with a single cluster.
    _, axes = plt.subplots(1, len(hsv_colors), figsize=(5, 2), squeeze=False)
    for i, (ax, hsv_color) in enumerate(zip(axes[0], hsv_colors)):
        # Round and clip before the uint8 cast so values neither truncate nor wrap.
        hsv_pixel = np.clip(np.rint([[hsv_color]]), 0, 255).astype(np.uint8)
        rgb_color = cv2.cvtColor(hsv_pixel, cv2.COLOR_HSV2RGB)[0, 0]
        ax.add_patch(patches.Rectangle((0, 0), 1, 1, facecolor=tuple(rgb_color / 255.0)))
        ax.axis('off')
        ax.set_title(f'Cluster {i}')
    plt.tight_layout()
    plt.show()
# process_frames writes both the *_clusters.txt reports and the segmented frames
# (under the original frame names) into the frame directory.
cluster_file_path = frames_output_dir
image_path = frames_output_dir
display_cluster_colors_with_image(cluster_file_path, image_path, video_path)
def show_video(video_path, width=600, mime_type="video/mp4"):
    """Build an HTML ``<video>`` player that embeds a local video file.

    The whole file is read and inlined as a base64 ``data:`` URL, so the
    notebook output grows with the size of the video.

    Args:
        video_path: Path of the video file to embed.
        width: Display width of the player in pixels.
        mime_type: MIME type advertised for the embedded data.

    Returns:
        An ``HTML`` object; display it (or make it a cell's last expression) to
        render the player.
    """
    # Read through a context manager so the file handle is closed promptly.
    with open(video_path, 'rb') as f:
        encoded = b64encode(f.read()).decode()
    data_url = f"data:{mime_type};base64,{encoded}"
    return HTML("""
<video width="{0}" controls>
<source src="{1}" type="{2}">
</video>
""".format(width, data_url, mime_type))
# Render the segmented video inline. display() is explicit, so the player appears
# even when this call is not the last expression of the cell.
display(show_video(segmented_video_output_path))